package org.zjvis.datascience.service.dataset;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.db.Db;
import cn.hutool.db.Entity;
import cn.hutool.db.ds.simple.SimpleDataSource;
import cn.hutool.db.meta.MetaUtil;
import cn.hutool.db.meta.Table;
import cn.hutool.db.sql.Wrapper;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.zjvis.datascience.common.constant.DatabaseConstant;
import org.zjvis.datascience.common.dto.DataConfigDTO;
import org.zjvis.datascience.common.dto.DatasetDTO;
import org.zjvis.datascience.common.dto.PendingDatasetDTO;
import org.zjvis.datascience.common.dto.dataset.DatasetConfigInfo;
import org.zjvis.datascience.common.dto.dataset.DatasetJsonInfo;
import org.zjvis.datascience.common.dto.dataset.DatasetScheduleDTO;
import org.zjvis.datascience.common.enums.ImportTaskStatusEnum;
import org.zjvis.datascience.common.exception.BaseErrorCode;
import org.zjvis.datascience.common.exception.DataScienceException;
import org.zjvis.datascience.common.pool.BasePool;
import org.zjvis.datascience.common.util.JwtUtil;
import org.zjvis.datascience.common.util.RedisUtil;
import org.zjvis.datascience.common.util.SqlUtil;
import org.zjvis.datascience.common.util.db.JDBCUtil;
import org.zjvis.datascience.common.vo.dataset.BatchImportDatasetProgressVO;
import org.zjvis.datascience.common.vo.dataset.DatasetDBTableVO;
import org.zjvis.datascience.common.vo.dataset.ImportDatasetProgressVO;
import org.zjvis.datascience.common.vo.db.DBImportVO;
import org.zjvis.datascience.service.dataprovider.GPDataProvider;

import javax.annotation.PostConstruct;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

/**
 * @description 数据上传Service
 * @date 2021-12-03
 */
@Service
public class ImportDataService {
    // Logger keeps its original string-based name ("ImportDataService") so any
    // existing logging configuration keyed on that category still matches.
    private final static Logger logger = LoggerFactory.getLogger("ImportDataService");

    @Autowired
    private BasePool threadPool;
    @Autowired
    private RedisUtil redisUtil;
    @Autowired
    private GPDataProvider gpDataProvider;
    @Autowired
    private JDBCUtil jdbcUtil;
    @Autowired
    private DatasetService datasetService;

    /** Guards every read-modify-write cycle on the Redis-backed import-task cache. */
    private static final Object CACHE_MODIFY_LOCK = new Object();

    /**
     * On startup, resume cached import tasks. Only tasks still in progress are
     * resumed; each resumed task re-imports into a freshly generated target table.
     */
    @PostConstruct
    private void resumeTask() {
        Set<Object> keys = redisUtil.hKeys(DatabaseConstant.DATA_IMPORT_TASK_POOL);
        if (CollectionUtil.isEmpty(keys)) {
            return;
        }
        keys.forEach(id -> resumeTask(id.toString()));
    }

    /**
     * Resume all cached tasks of one user; removes the cache entry when it is empty.
     *
     * @param uid user id in string form, as stored in the Redis hash
     */
    private void resumeTask(String uid) {
        Map<String, PendingDatasetDTO> tasks = getTask(uid);
        if (CollectionUtil.isEmpty(tasks)) {
            rmTask(Long.parseLong(uid));
            return;
        }
        renewCache(tasks, uid);
        resumeTask(tasks);
    }

    /**
     * Kick off the actual import for a batch of tasks. All tasks in one batch share
     * the same source datasource/credentials, so any one entry can supply them.
     *
     * @param tasks non-empty map of source table name -> pending task
     */
    private void resumeTask(Map<String, PendingDatasetDTO> tasks) {
        // Caller guarantees tasks is non-empty (checked in resumeTask(String)).
        String sampleTable = tasks.keySet().stream().findFirst().get();
        PendingDatasetDTO sampleTask = tasks.get(sampleTable);
        DataSource fromDs = JDBCUtil.getDataSource(sampleTask.getUrl(), sampleTask.getUser(), sampleTask.getPassword());
        executeImport(fromDs, tasks, sampleTask.getUserId(), sampleTask.getDatabaseType());
    }

    /**
     * Renew each still-pending task (new target table, new start time) and write the
     * batch back to the cache.
     *
     * @param tasks the user's task batch, mutated in place by {@link #renewTask}
     * @param uid   owning user id in string form
     */
    private void renewCache(Map<String, PendingDatasetDTO> tasks, String uid) {
        for (Map.Entry<String, PendingDatasetDTO> entry : tasks.entrySet()) {
            PendingDatasetDTO t = entry.getValue();
            if (ImportTaskStatusEnum.PENDING.equals(t.getStatus())) {
                renewTask(t);
            }
        }
        // Avoid resurrecting stale data for a task that was cancelled or already
        // finished: only write back while the cache entry still exists.
        synchronized (CACHE_MODIFY_LOCK) {
            if (existTask(uid)) {
                recordTask(uid, tasks);
            }
        }
    }

    /** @return whether the user currently has a cached import-task batch. */
    private boolean existTask(String uid) {
        return redisUtil.hHasKey(DatabaseConstant.DATA_IMPORT_TASK_POOL, uid);
    }

    /**
     * When resuming, regenerate the target table name and refresh the related
     * task/dataset state so the re-import starts clean.
     *
     * @param task pending task to renew; mutated in place
     */
    private void renewTask(PendingDatasetDTO task) {
        String newTargetTable = generateGpTableName();

        DatasetDTO ds = task.getDataset();
        // NOTE(review): DataInfo is not imported here, so it presumably resolves from
        // this package; it differs from the DatasetJsonInfo type used when the task
        // was first built — confirm both carry the same "table" field semantics.
        DataInfo di = JSON.parseObject(ds.getDataJson(), DataInfo.class);
        di.setTable(newTargetTable);
        ds.setDataJson(JSON.toJSONString(di));

        task.setTargetTableName(newTargetTable);
        task.setStartTime(System.currentTimeMillis());
    }

    /**
     * Cache an import-task batch for a user.
     *
     * @param uid  user id
     * @param data map keyed by source table name
     */
    public void recordTask(long uid, Map<String, PendingDatasetDTO> data) {
        recordTask(Long.toString(uid), data);
    }

    public void recordTask(String uid, Map<String, PendingDatasetDTO> data) {
        redisUtil.hset(DatabaseConstant.DATA_IMPORT_TASK_POOL, uid, data);
    }

    /** Remove the current user's cached task batch. */
    public void rmTask() {
        rmTask(JwtUtil.getCurrentUserId());
    }

    public void rmTask(long uid) {
        redisUtil.hdel(DatabaseConstant.DATA_IMPORT_TASK_POOL, Long.toString(uid));
    }

    /**
     * Fetch the cached task batch of a user.
     *
     * @param uid user id
     * @return the batch keyed by source table name, or {@code null} if none cached
     */
    @SuppressWarnings("unchecked") // cache stores only Map<String, PendingDatasetDTO> (see recordTask)
    public Map<String, PendingDatasetDTO> getTask(String uid) {
        Object obj = redisUtil.hget(DatabaseConstant.DATA_IMPORT_TASK_POOL, uid);
        if (obj == null) {
            return null;
        }
        return (Map<String, PendingDatasetDTO>) obj;
    }

    public Map<String, PendingDatasetDTO> getTask(long uid) {
        return getTask(Long.toString(uid));
    }

    /**
     * Pre-flight check: a user may only run one import batch at a time. Clears a
     * fully-finished batch first, then rejects if a batch still exists.
     *
     * @throws DataScienceException if the user already has an active import batch
     */
    public void importDataCheck() {
        long uid = JwtUtil.getCurrentUserId();

        Map<String, PendingDatasetDTO> tasks = getTask(uid);
        clearIfAllFinish(tasks, uid);
        // FIX: synchronize on the shared cache lock (was "synchronized (this)") so this
        // existence check serializes with the other cache read-modify-write sections
        // (renewCache / importTaskFinish / recordCancel), which all use CACHE_MODIFY_LOCK.
        synchronized (CACHE_MODIFY_LOCK) {
            if (existTask(Long.toString(uid))) {
                throw new DataScienceException(BaseErrorCode.DATASET_IMPORT_TASK_EXIST);
            }
        }
    }

    /**
     * Entry point: validate no batch is running, then import the requested tables.
     */
    public void importDataset(DBImportVO vo) {
        importDataCheck();
        doImportDataset(vo);
    }

    /**
     * Build the source datasource, record the task batch, and start the async import.
     *
     * @param vo import request (source DB info plus per-table configs)
     * @throws DataScienceException for unsupported database types or setup failures
     */
    public void doImportDataset(DBImportVO vo) {
        String dbType = vo.getDatabaseType();
        dbType = dbType == null ? "" : dbType.trim().toLowerCase();

        switch (dbType) {
            case "mysql":
            case "oracle":
                break;
            case "rds-mysql":
                // rds-mysql is handled exactly like mysql downstream, so normalize here.
                dbType = "mysql";
                break;
            default:
                throw new DataScienceException(BaseErrorCode.DATASET_PREVIEW_UNSUPPORTED);
        }
        String url = JDBCUtil.getUrl(vo.getServer(), vo.getPort(), vo.getDatabaseName(), dbType,
                vo.getConnectType(), vo.getConnectValue());
        DataSource fromDs = JDBCUtil.getDataSource(url, vo.getUser(), vo.getPassword());

        try {
            Map<String, PendingDatasetDTO> tasks = recordTask(vo, fromDs);
            executeImport(fromDs, tasks, JwtUtil.getCurrentUserId(), dbType);
        } catch (SQLException e) {
            // Setup failed before any async work started: drop the cached batch.
            rmTask();
            throw new DataScienceException(BaseErrorCode.DATASET_IMPORT_ERROR, e);
        }
    }

    /**
     * Dispatch one async import job per pending table on the shared thread pool.
     */
    private void executeImport(DataSource fromDs, Map<String, PendingDatasetDTO> tasks, long uid, String sourceDbType) {
        for (Map.Entry<String, PendingDatasetDTO> entry : tasks.entrySet()) {
            String table = entry.getKey();
            PendingDatasetDTO task = entry.getValue();
            CompletableFuture.runAsync(() -> {
                // Skip tasks that were cancelled/finished before the job started.
                if (!ImportTaskStatusEnum.PENDING.equals(task.getStatus())) {
                    return;
                }
                DatasetDTO dataset = task.getDataset();
                DatasetJsonInfo dj = JSONObject.parseObject(dataset.getDataJson(), DatasetJsonInfo.class);
                executeImport(table, fromDs, task.getTargetTableName(), task.getCount(), uid, sourceDbType, dj, true, "-1", null);
            }, threadPool.getExecutor());
        }
    }

    /**
     * Run one table's import and, for first-time imports, record completion state.
     *
     * @param t               source table name
     * @param fromDs          source datasource
     * @param targetTable     generated GP target table name
     * @param count           source row count (for progress reporting)
     * @param uid             owning user id
     * @param sourceDbType    normalized source db type ("mysql"/"oracle")
     * @param dj              dataset json info (target schema/table/columns)
     * @param firstImport     true for initial import, false for incremental refresh
     * @param lastValue       last incremental watermark value ("-1" on first import)
     * @param incrementColumn incremental column name, or null when not scheduled
     */
    public void executeImport(String t, DataSource fromDs, String targetTable, int count, long uid, String sourceDbType, DatasetJsonInfo dj, boolean firstImport, String lastValue, String incrementColumn) {
        String err = null;
        try {
            importData(fromDs, t, targetTable, count, sourceDbType, dj, lastValue, incrementColumn, firstImport);
        } catch (Exception e) {
            logger.error(e.getMessage(), e);
            err = ExceptionUtil.stacktraceToString(e);
        }
        if (firstImport) {
            importTaskFinish(uid, t, err);
        }
    }

    /**
     * Build and cache the pending-task batch for an import request.
     *
     * @return the batch keyed by source table name
     * @throws SQLException if source-table inspection (row count, max value) fails
     */
    public Map<String, PendingDatasetDTO> recordTask(DBImportVO vo, DataSource ds) throws SQLException {
        Map<String, PendingDatasetDTO> tasks = new HashMap<>(16);
        long uid = JwtUtil.getCurrentUserId();
        for (DatasetDBTableVO tb : vo.getTableConfigs()) {
            PendingDatasetDTO pendingDatasetDTO = buildPendingInfo(ds, tb.getTableName(), vo, uid, tb);
            tasks.put(tb.getTableName(), pendingDatasetDTO);
        }
        recordTask(uid, tasks);
        return tasks;
    }

    /**
     * Build the pending-task entry for one source table: target table name, dataset
     * json/config, optional incremental-schedule config, and the dataset DTO itself.
     *
     * @throws SQLException if querying the source table fails
     */
    public PendingDatasetDTO buildPendingInfo(DataSource ds, String table, DBImportVO dbInfo, long uid, DatasetDBTableVO tb) throws SQLException {
        String targetTableName = generateGpTableName();

        DatasetJsonInfo dj = DatasetJsonInfo.builder()
                .schema(DatabaseConstant.GREEN_PLUM_DEFAULT_SCHEMA)
                .table(targetTableName)
                .type(dbInfo.getImportType())
                .columnMessage(tb.getData())
                .build();

        DataConfigDTO dataConfigDTO = DataConfigDTO.builder()
                .databaseType(dbInfo.getDatabaseType())
                .server(dbInfo.getServer())
                .port(dbInfo.getPort())
                .user(dbInfo.getUser())
                .password(dbInfo.getPassword())
                .databaseName(dbInfo.getDatabaseName())
                .tableName(table)
                .url(((SimpleDataSource) ds).getUrl())
                .build();

        DatasetScheduleDTO datasetScheduleDTO = tb.getDatasetScheduleConfig();
        DatasetConfigInfo datasetConfigInfo;
        if (datasetScheduleDTO == null || datasetScheduleDTO.getNeedSchedule() == null || !datasetScheduleDTO.getNeedSchedule()) {
            datasetConfigInfo = DatasetConfigInfo.builder()
                    .needSchedule(false)
                    .build();
        } else {
            String incrementalColumn = datasetScheduleDTO.getIncrementColumn();
            // Oracle exposes its pseudo-column as "rowid"; the front end sends "_rowid".
            if (dbInfo.getDatabaseType().equals("oracle") && incrementalColumn.equals("_rowid")) {
                incrementalColumn = "rowid";
            }
            datasetConfigInfo = DatasetConfigInfo.builder()
                    .needSchedule(true)
                    .crontab(datasetService.getCrontab(datasetScheduleDTO))
                    .incrementColumn(incrementalColumn)
                    .lastValue(datasetService.getMaxValue(ds, table, datasetScheduleDTO.getIncrementColumn()))
                    .detail(datasetScheduleDTO)
                    .build();
        }
        DatasetDTO datasetDTO = datasetService.buildDataset(uid, dbInfo.getCategoryId(), table, dj, dataConfigDTO, datasetConfigInfo);
        return buildPendingInfo(ds, datasetDTO, uid, targetTableName, dbInfo);
    }

    /**
     * Assemble the final pending DTO, counting source rows up front for progress.
     *
     * @throws SQLException if the row-count query fails
     */
    public PendingDatasetDTO buildPendingInfo(DataSource ds, DatasetDTO datasetDTO, long uid, String targetTableName
            , DBImportVO dbInfo) throws SQLException {
        Db db = Db.use(ds);
        switch (dbInfo.getDatabaseType().toLowerCase()) {
            case "oracle":
                // Oracle identifiers must be double-quoted, not back-ticked.
                Wrapper oldw = db.getRunner().getDialect().getWrapper();
                oldw.setPreWrapQuote('\"');
                oldw.setSufWrapQuote('\"');
                break;
        }

        int count = db.count(new Entity(datasetDTO.getName()));

        return PendingDatasetDTO.builder()
                .count(count)
                .databaseType(dbInfo.getDatabaseType())
                .tableName(datasetDTO.getName())
                .dataset(datasetDTO)
                .userId(uid)
                .url(JDBCUtil.getUrl(dbInfo.getServer(), dbInfo.getPort(), dbInfo.getDatabaseName(),
                        dbInfo.getDatabaseType(), dbInfo.getConnectType(), dbInfo.getConnectValue()))
                .user(dbInfo.getUser())
                .password(dbInfo.getPassword())
                .targetTableName(targetTableName)
                .startTime(System.currentTimeMillis())
                .status(ImportTaskStatusEnum.PENDING)
                .build();
    }

    /**
     * Mark one table's task finished (success or failure), persist the dataset on
     * success, and write the updated batch back to the cache.
     *
     * @param uid   owning user id
     * @param table source table name
     * @param err   stack-trace string on failure, or {@code null} on success
     */
    public void importTaskFinish(long uid, String table, String err) {
        synchronized (CACHE_MODIFY_LOCK) {
            Map<String, PendingDatasetDTO> tasks = getTask(uid);
            if (tasks == null) {
                logger.error("dataset import task pool doesn't have this user's task. Maybe has been canceled. uid:" + uid);
                return;
            }

            PendingDatasetDTO t = tasks.get(table);
            if (t == null) {
                logger.error("dataset import task pool doesn't have this task. tableName:" + table);
                return;
            }

            // A non-pending status here usually means the task was cancelled.
            if (!ImportTaskStatusEnum.PENDING.equals(t.getStatus())) {
                return;
            }

            t.setStatus(err == null ? ImportTaskStatusEnum.SUCCESS : ImportTaskStatusEnum.FAIL);
            t.setErrorInfo(err);
            t.setCostTime(System.currentTimeMillis() - t.getStartTime());

            // FIX: the task argument needs a {} placeholder, otherwise SLF4J drops it.
            logger.info("import data task finish: {}", t);
            tasks.put(table, t);
            finishIfSuccess(t);

            recordTask(uid, tasks);
        }
    }

    /** Persist the dataset record only when the import completed without error. */
    private void finishIfSuccess(PendingDatasetDTO pendingDataset) {
        if (pendingDataset.getErrorInfo() != null) {
            return;
        }

        datasetService.insert(pendingDataset.getDataset());
    }

    /** @return the current user's cached task batch, or {@code null} if none. */
    public Map<String, PendingDatasetDTO> getTasks() {
        return getTask(JwtUtil.getCurrentUserId());
    }

    /**
     * Cancel the current user's batch. Removing each table's progress key makes the
     * worker (JDBCUtil#writeToTable) stop: it skips work when no progress entry exists.
     */
    public void cancel() {
        Map<String, PendingDatasetDTO> tasks = recordCancel();
        if (CollectionUtil.isEmpty(tasks)) {
            return;
        }

        for (Map.Entry<String, PendingDatasetDTO> entry : tasks.entrySet()) {
            PendingDatasetDTO t = entry.getValue();
            JDBCUtil.IMPORT_TASK_PROGRESS.remove(t.getTargetTableName());
        }
    }

    /**
     * Mark all pending tasks as cancelled in the cache.
     *
     * @return the updated batch, or {@code null} when the user has no batch
     * @throws DataScienceException if every task already finished (nothing to cancel)
     */
    private Map<String, PendingDatasetDTO> recordCancel() {
        long user = JwtUtil.getCurrentUserId();
        synchronized (CACHE_MODIFY_LOCK) {
            Map<String, PendingDatasetDTO> tasks = getTask(user);
            if (CollectionUtil.isEmpty(tasks)) {
                return null;
            }
            boolean hasPending = false;
            for (Map.Entry<String, PendingDatasetDTO> entry : tasks.entrySet()) {
                if (ImportTaskStatusEnum.PENDING.equals(entry.getValue().getStatus())) {
                    hasPending = true;
                    entry.getValue().setStatus(ImportTaskStatusEnum.CANCEL);
                }
            }
            if (hasPending) {
                recordTask(user, tasks);
                return tasks;
            }
            throw new DataScienceException(BaseErrorCode.DATASET_IMPORT_CANCEL_ERROR);
        }
    }

    /**
     * Progress query for the async import batch.
     *
     * @return per-table progress plus a final flag, or {@code null} when no batch exists
     */
    public BatchImportDatasetProgressVO getProgress() {
        long uid = JwtUtil.getCurrentUserId();
        Map<String, PendingDatasetDTO> tasks = getTask(uid);
        if (CollectionUtil.isEmpty(tasks)) {
            return null;
        }

        Collection<PendingDatasetDTO> taskDtos = tasks.values();

        return BatchImportDatasetProgressVO.builder()
                // Clear eagerly so a later failed query cannot leave the batch stuck.
                .finalFlag(clearIfAllFinish(tasks, uid))
                .progresses(taskDtos.stream().map(this::toProgress).collect(Collectors.toList()))
                .build();
    }

    /** Map one pending task to its progress view object. */
    private ImportDatasetProgressVO toProgress(PendingDatasetDTO ds) {
        return ImportDatasetProgressVO.builder()
                .err(ds.getErrorInfo())
                .table(ds.getTableName())
                .count(ds.getCount())
                .status(ds.getStatus().getDesc())
                .progress(computeProgress(ds))
                .costTime(ds.getCostTime())
                .datasetId(ds.getDataset().getId())
                .build();
    }

    /**
     * Progress in basis points (0..10000) from the worker's shared progress map;
     * falls back to the init value when the lookup fails.
     */
    private int computeProgress(PendingDatasetDTO ds) {
        if (ds.getCount() == 0 || ImportTaskStatusEnum.SUCCESS.equals(ds.getStatus())) {
            return DatabaseConstant.DATA_IMPORT_PROGRESS_FINISH;
        }

        try {
            int finishCount = JDBCUtil.IMPORT_TASK_PROGRESS.getOrDefault(ds.getTargetTableName(), 0);
            // long arithmetic avoids int overflow on large tables.
            return (int) (((long) finishCount * 10000) / ds.getCount());
        } catch (Exception e) {
            return DatabaseConstant.DATA_IMPORT_PROGRESS_INIT;
        }
    }

    /**
     * Remove the batch from the cache once no task is pending any more.
     *
     * @return true when the batch is finished (or absent) and has been cleared
     */
    public boolean clearIfAllFinish(Map<String, PendingDatasetDTO> tasks, long uid) {
        if (CollectionUtil.isEmpty(tasks)) {
            return true;
        }
        for (PendingDatasetDTO t : tasks.values()) {
            if (ImportTaskStatusEnum.PENDING.equals(t.getStatus())) {
                return false;
            }
        }
        rmTask(uid);
        return true;
    }

    /**
     * Create the GP target table (first import only) and stream the source rows in.
     *
     * @throws Exception on any source read, table create, or write failure
     */
    public void importData(DataSource fromDs, String table, String targetTable, int size, String sourceDbType, DatasetJsonInfo dj, String lastValue, String incrementColumn, boolean firstImport) throws Exception {
        // FIX: the connection obtained here only to log metadata was never closed
        // before (connection leak); try-with-resources releases it immediately.
        try (Connection metaConn = fromDs.getConnection()) {
            logger.info(String.format("import data. from %s, table: %s, to: %s", metaConn.getMetaData(), table, targetTable));
        }

        try (Connection toConn = gpDataProvider.getConn(DatabaseConstant.GREEN_PLUM_DATASET_ID)) {
            Table tbMeta = createTable(fromDs, toConn, table, targetTable, DatabaseConstant.GREEN_PLUM_DEFAULT_SCHEMA, sourceDbType, dj, firstImport);
            jdbcUtil.writeToTable(fromDs, tbMeta, gpDataProvider.getSource(DatabaseConstant.GREEN_PLUM_DATASET_ID)
                    , targetTable, size, sourceDbType, dj.getColumnMessage(), lastValue, incrementColumn);

        }
    }

    /**
     * Read the source table's metadata, adapt it to PostgreSQL/GP types, and (on
     * first import) create the target table.
     *
     * @return the adapted table metadata, with its name restored to the source name
     * @throws DataScienceException when the table has no importable field or the
     *                              CREATE TABLE statement fails
     */
    public Table createTable(DataSource fromDs, Connection to, String table, String targetTable, String schema, String sourceDbType, DatasetJsonInfo dj, boolean firstImport) {
        Table tableMeta = MetaUtil.getTableMeta(fromDs, table);
        String originName = tableMeta.getTableName();
        tableMeta.setTableName(schema + "." + targetTable);

        // Drop excluded columns, apply masking-type conversion, and other per-column handling.
        tableMeta = SqlUtil.columnSpecialProcess(tableMeta, dj);

        switch (sourceDbType.toLowerCase()) {
            case "oracle":
                tableMeta = SqlUtil.changeOracleSchemaToPGSql(tableMeta);
                break;
            default:
                tableMeta = SqlUtil.changeMysqlSchemaToPGSql(tableMeta);
                break;
        }
        if (firstImport) {
            String sql = SqlUtil.generateCreateTableSqlForGP(tableMeta);
            if (sql == null) {
                throw DataScienceException.of(BaseErrorCode.DATASET_TABLE_NO_FIELD, table);
            }
            // FIX: the sql value was concatenated onto the format string instead of
            // being passed as the {} argument, logging a literal "{}" and no SQL.
            logger.info("import data create table sql:{}", sql);
            try {
                JDBCUtil.execute(to, sql);
            } catch (SQLException e) {
                throw new DataScienceException(BaseErrorCode.DATASET_TABLE_GP_CREATE_FAIL, e);
            } finally {
                JDBCUtil.close(to, null, null);
            }
        }

        tableMeta.setTableName(originName);
        return tableMeta;
    }

    /**
     * Generate a unique GP table name: random letter + user id + current millis.
     * The leading letter keeps the identifier valid even though it embeds digits.
     */
    public static String generateGpTableName() {
        return StrUtil.format("{}_{}_{}",
                RandomUtil.randomChar(RandomUtil.BASE_CHAR), JwtUtil.getCurrentUserId(), System.currentTimeMillis());
    }
}

